from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import Dataset, DataLoader
import torch
from peft import PeftModel
import json
import tqdm
from eval import eval_file
import argparse

class GSM8kTestDataset(Dataset):
    """Dataset that yields the raw question text of each record.

    Each record in ``data`` is indexed with the ``"question"`` key. When
    ``add_special_token`` is true the question is wrapped in
    ``<|im_start|>`` / ``<|im_end|>`` markers; otherwise it is returned
    unchanged.
    """

    def __init__(self, data, add_special_token=False):
        self.add_special_token = add_special_token
        self.data = data

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the prompt string for the record at position ``idx``."""
        question = self.data[idx]["question"]
        if not self.add_special_token:
            return question
        return f'<|im_start|>{question}<|im_end|>'


class GSM8kChatDataset(Dataset):
    """Dataset that renders each question through the tokenizer's chat template.

    Every item is the untokenized string produced by
    ``tokenizer.apply_chat_template`` for a two-message conversation: a fixed
    system instruction followed by the record's ``"question"`` as the user
    turn, with the generation prompt appended.
    """

    def __init__(self, data, tokenizer):
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the chat-formatted prompt for the record at ``idx``."""
        messages = [
            {"role": "system", "content": "Please solve the following problem step by step."},
            {"role": "user", "content": self.data[idx]["question"]},
        ]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )


def collate_fn(batch, tokenizer):
    """Tokenize a batch of prompt strings into padded PyTorch tensors.

    Args:
        batch: The prompt strings collected by the DataLoader.
        tokenizer: Callable tokenizer applied to the whole batch at once.

    Returns:
        Whatever ``tokenizer`` returns for the batch with padding enabled,
        truncation enabled at 1024 tokens and ``return_tensors="pt"``.
    """
    encode_options = {
        "padding": True,
        "truncation": True,
        "max_length": 1024,  # limit input length
        "return_tensors": "pt",
    }
    return tokenizer(batch, **encode_options)


def _load_jsonl(data_path, data_num=None):
    """Read a JSON-lines file into a list, optionally keeping the first ``data_num`` records."""
    with open(data_path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        records = [json.loads(line) for line in f if line.strip()]
    if data_num is not None:
        records = records[:data_num]
    return records


def _build_output_prefix(model_path, base_model_name, do_sample, temperature, data_num, use_template, add_special_token):
    """Build the result-file path (without the ``.jsonl`` suffix) that encodes the run settings."""
    if model_path is not None:
        prefix = f'{model_path}/inference_results'
    else:
        prefix = f'models/{base_model_name}/inference_results'

    if do_sample:
        prefix += '-sample-T' + str(temperature)
    else:
        prefix += '-greedy'
    if data_num is not None:
        prefix += f'-{data_num}'
    if use_template:
        prefix += '-chat_template'
    elif add_special_token:
        prefix += '-special_token'
    return prefix


def inference(model_path, base_model_name, data_path='./gsm8k/test.jsonl', do_sample=False, temperature=0.7, data_num=None, use_template=False, add_special_token=False):
    """Generate answers for a JSON-lines question file, save them and evaluate.

    Loads ``base_model_name`` (optionally wrapped with the PEFT adapter stored
    at ``model_path``), generates one completion per record, writes every
    record plus a ``"generated"`` field to a ``.jsonl`` file whose name encodes
    the run settings, and finally calls ``eval_file`` on that file.

    Args:
        model_path: Location of a PEFT adapter, or ``None`` to run the base
            model as-is. Also determines where the results are written.
        base_model_name: Name or path used to load the model and tokenizer.
        data_path: JSON-lines file; each record must contain a ``"question"`` key.
        do_sample: Sample instead of decoding greedily.
        temperature: Sampling temperature; only forwarded when ``do_sample`` is true.
        data_num: If given, only the first ``data_num`` records are used.
        use_template: Format prompts with the tokenizer's chat template.
        add_special_token: When no chat template is used, wrap each question
            in ``<|im_start|>`` / ``<|im_end|>``.

    Returns:
        None. Results are written to disk and passed to ``eval_file``.
    """
    import os  # local import: only needed to create the output directory

    base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # For some models like Llama/Mistral
    # Batched generation with a causal LM continues from the last position of
    # each row, so padding must sit on the left; with right padding the model
    # would be asked to continue from pad tokens.
    tokenizer.padding_side = 'left'

    if model_path is not None:
        # use peft
        model = PeftModel.from_pretrained(base_model, model_path)
    else:
        model = base_model

    # load data
    data = _load_jsonl(data_path, data_num)

    if use_template:
        dataset = GSM8kChatDataset(data, tokenizer)
    else:
        dataset = GSM8kTestDataset(data, add_special_token=add_special_token)
    dataloader = DataLoader(dataset, batch_size=16, collate_fn=lambda x: collate_fn(x, tokenizer))

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    # The generation settings do not depend on the batch, so build them once.
    generation_config = {
        "max_new_tokens": 2048,  # limit generation length
        "do_sample": do_sample,
        # Pad finished sequences with the same token used to pad the inputs.
        "pad_token_id": tokenizer.pad_token_id,
        "use_cache": True,  # enable the KV cache
    }
    if do_sample:
        # Temperature has no effect on greedy decoding; forwarding it there
        # only triggers a configuration warning.
        generation_config["temperature"] = temperature

    inference_results = []
    with torch.no_grad():
        for batch in tqdm.tqdm(dataloader):
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model.generate(**inputs, **generation_config)
            # The full returned sequences are decoded as-is (for causal LMs
            # this normally includes the prompt tokens as well).
            inference_results.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))

    output_dir = _build_output_prefix(
        model_path, base_model_name, do_sample, temperature, data_num, use_template, add_special_token
    )
    output_file = f'{output_dir}.jsonl'

    # The target directory may not exist yet (e.g. the base-model branch under
    # models/<base_model_name>/); create it so the finished run is not lost.
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_file, 'w', encoding='utf-8') as f:
        for text, item in zip(inference_results, data):
            item["generated"] = text
            f.write(json.dumps(item) + '\n')

    eval_file(output_file)


if __name__ == '__main__':
    # Command-line entry point. Experiment settings are passed as flags rather
    # than edited in the source; running the script without any arguments
    # reproduces the most recent hard-coded run (chat template, greedy decoding,
    # full test set, the checkpoint below).
    #
    # Examples:
    #   python <this script> --base_only                       # evaluate the base model only
    #   python <this script> --model_path <dir> --data_num 100 # first 100 questions
    #   python <this script> --no_template --add_special_token # raw prompts with <|im_start|>/<|im_end|>
    #   python <this script> --do_sample --temperature 0.7     # sampling instead of greedy
    parser = argparse.ArgumentParser(
        description='Run GSM8k inference with a (PEFT-adapted) causal LM and evaluate the results.'
    )
    parser.add_argument('--base_model_name', type=str, default='Qwen/Qwen2.5-1.5B-Instruct',
                        help='name or path of the base model and tokenizer')
    parser.add_argument('--model_path', type=str,
                        default='models/Qwen/Qwen2.5-1.5B-Instruct-template-LoRA-seed2026-SFT-b4-r8-alp32-lr5e-07-q-k-v-o-wte-lm_head-expert_threshold_5_new-0327-185940/ckpt-2000',
                        help='path of the PEFT adapter checkpoint to evaluate')
    parser.add_argument('--base_only', action='store_true',
                        help='ignore --model_path and evaluate the base model without an adapter')
    parser.add_argument('--data_path', type=str, default='./gsm8k/test.jsonl',
                        help='JSON-lines file with one question per line')
    parser.add_argument('--data_num', type=int, default=None,
                        help='only evaluate the first N questions (default: all)')
    # Boolean options are flags: argparse's type=bool would treat any
    # non-empty string (including "False") as True.
    parser.add_argument('--do_sample', action='store_true',
                        help='sample instead of greedy decoding')
    parser.add_argument('--temperature', type=float, default=0.7,
                        help='sampling temperature (used with --do_sample)')
    parser.add_argument('--no_template', dest='use_template', action='store_false',
                        help='do not apply the tokenizer chat template (applied by default)')
    parser.add_argument('--add_special_token', action='store_true',
                        help='without the chat template, wrap questions in <|im_start|>/<|im_end|>')
    args = parser.parse_args()

    inference(
        model_path=None if args.base_only else args.model_path,
        base_model_name=args.base_model_name,
        data_path=args.data_path,
        do_sample=args.do_sample,
        temperature=args.temperature,
        data_num=args.data_num,
        use_template=args.use_template,
        add_special_token=args.add_special_token,
    )
# if __name__ == '__main__':
#     base_model_name = 'Qwen/Qwen2.5-1.5B'

#     # model_path = './models/Qwen/Qwen2.5-1.5B-LoRA-RL-expert-Qwen2.5-1.5B-0129-124535'
#     model_path = './models/Qwen/Qwen2.5-1.5B-LoRA-SFT-expert-Qwen2.5-1.5B-0129-073418'
#     # model_path = './models/Qwen/Qwen2.5-1.5B-LoRA-CQL-a10-imperfect-Qwen2.5-1.5B-0129-193242'

#     base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
#     tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
#     if tokenizer.pad_token is None:
#         tokenizer.pad_token = tokenizer.eos_token  # For some models like Llama/Mistral

#     # use peft
#     model = PeftModel.from_pretrained(base_model, model_path)
#     # model = base_model
#     # model_path = './models/Qwen/Qwen2.5-1.5B'
    
#     # load data
#     data_path = './gsm8k/test.jsonl'
#     with open(data_path, 'r') as f:
#         data = f.readlines()
#         data = [json.loads(d) for d in data]
    
#     dataset = GSM8kTestDataset(data)
#     dataloader = DataLoader(dataset, batch_size=8, collate_fn=lambda x: collate_fn(x, tokenizer))

#     device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#     model.to(device)
#     model.eval()

#     ############################
#     # configuration for generation
#     do_sample = False # False
#     temperature = 0.7
#     ############################

#     inference_results = []

#     for batch in tqdm.tqdm(dataloader):
#         inputs = {k: v.to(device) for k, v in batch.items()}
#         generation_config = {
#             "max_new_tokens": 1024,          # limit generation length
#             "do_sample": do_sample,
#             "temperature": temperature,
#             # "top_p": 0.9,
#             # "repetition_penalty": 1.2,
#             "pad_token_id": tokenizer.eos_token_id,
#             "use_cache": True               # enable the KV cache
#         }
#         with torch.no_grad():
#             outputs = model.generate(**inputs, **generation_config)
#             batch_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
#             inference_results.extend(batch_texts)
    
#     output_dir = f'{model_path}/inference_results'

#     if not do_sample:
#         output_dir += '-greedy'
#     else:
#         output_dir += '-sample-T' + str(temperature)

#     with open(f'{output_dir}.jsonl', 'w') as f:
#         for text, item in zip(inference_results, data):
#             item["generated"] = text
#             f.write(json.dumps(item) + '\n')
        